Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move compilation of CUDA code to NVRTC #131

Merged
merged 153 commits into from
Oct 17, 2024
Merged

Move compilation of CUDA code to NVRTC #131

merged 153 commits into from
Oct 17, 2024

Conversation

nickjbrowning
Copy link
Collaborator

No description provided.

Copy link

github-actions bot commented Jul 30, 2024

Here is a pre-built version of the code in this pull request: wheels.zip, you can install it locally by unzipping wheels.zip and using pip to install the file matching your system

@nickjbrowning
Copy link
Collaborator Author

I also see some compilation warnings:

      In file included from /home/filippo/code/sphericart/sphericart-jax/sphericart/src/cuda_base.cpp:4:
      /home/filippo/code/sphericart/sphericart-jax/sphericart/include/cuda_cache.hpp: In member function ‘void CachedKernel::checkAndAdjustSharedMem(int)’:
      /home/filippo/code/sphericart/sphericart-jax/sphericart/include/cuda_cache.hpp:183:22: warning: unused variable ‘res’ [-Wunused-variable]
        183 |             CUresult res = driver.cuCtxGetDevice(&cuDevice);
            |                      ^~~
      /home/filippo/code/sphericart/sphericart-jax/sphericart/src/cuda_base.cpp: In function ‘size_t total_buffer_size(size_t, size_t, size_t, size_t, bool, bool)’:
      /home/filippo/code/sphericart/sphericart-jax/sphericart/src/cuda_base.cpp:22:12: warning: unused parameter ‘GRID_DIM_X’ [-Wunused-parameter]
         22 |     size_t GRID_DIM_X,
            |     ~~~~~~~^~~~~~~~~~

this should be fixed now.

.clang-format Outdated Show resolved Hide resolved
static CacheMapCUDA<C, T> sph_cache;
static std::mutex cache_mutex;

// Check if instance exists in cache, if not create and store it
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@frostedoyster should we purge the instances after some point? Or do we let this map grow unbounded? (This would be for a separate PR!)

sphericart/CMakeLists.txt Outdated Show resolved Hide resolved
sphericart/include/dynamic_cuda.hpp Outdated Show resolved Hide resolved
sphericart/src/sphericart_cuda.cpp Outdated Show resolved Hide resolved
sphericart/include/dynamic_cuda.hpp Outdated Show resolved Hide resolved
sphericart/src/sphericart_cuda.cpp Outdated Show resolved Hide resolved
Comment on lines +107 to +110
# Install this one manually. Listing it in the deps list above does not install jaxlib.
# Note: jax[cuda12] is not available on Windows and MacOS.
bash -c 'command -v nvcc && python -m pip install -U "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html || python -m pip install -U jax[cpu]'

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need the CUDA version of JAX to run the examples?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see the immediate downside and it does allow for catching more bugs on the CSCS CI

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this, it won't run the cuda code through jax, and we woul have indeed missed the bugs from last week

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But there is a separate test section above which is already using cuda, this is for the examples. If there is stuff in there that acts like a test, we should maybe move it to the tests instead.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this is fine for now, we can revisit later if this becomes a problem.

Co-authored-by: Guillaume Fraux <[email protected]>
Comment on lines 20 to 22
"Failed to load libcuda.so. Try running \"find /usr -name libcuda.so\" and "
"appending the directory to your $LD_LIBRARY_PATH environment variable."
);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, nice that this is now the only place with an error mesage! Could it be updated to not reference /usr since libcuda can be elsewhere? This was already done on the other messages but these fell through!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated it, do you agree with the new message?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, looks a lot better!

Copy link
Contributor

@Luthaf Luthaf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for your work @nickjbrowning!

@frostedoyster frostedoyster merged commit 877370a into main Oct 17, 2024
10 checks passed
@frostedoyster frostedoyster deleted the nvrtc branch October 17, 2024 20:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants